{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Writing guides using EasyGuide\n",
    "\n",
    "This tutorial describes the [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html) module. This tutorial assumes the reader is already familiar with [SVI](http://pyro.ai/examples/svi_part_ii.html) and [tensor shapes](http://pyro.ai/examples/tensor_shapes.html).\n",
    "\n",
    "#### Summary\n",
    "\n",
    "- For simple black-box guides, try using components in [pyro.infer.autoguide](http://docs.pyro.ai/en/stable/infer.autoguide.html).\n",
    "- For more complex guides, try using components in [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html).\n",
    "- Decorate with `@easy_guide(model)`.\n",
    "- Select multiple model sites using `group = self.group(match=\"my_regex\")`.\n",
    "- Guide a group of sites by a single distribution using `group.sample(...)`.\n",
    "- Inspect concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.\n",
    "- Use `self.plate(...)` instead of `pyro.plate(...)`.\n",
    "- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.\n",
    "- To MAP estimate model site \"foo\", use `foo = self.map_estimate(\"foo\")`.\n",
    "\n",
    "#### Table of contents\n",
    "\n",
    "- [Modeling time series data](#Modeling-time-series-data)\n",
    "- [Writing a guide without EasyGuide](#Writing-a-guide-without-EasyGuide)\n",
    "- [Using EasyGuide](#Using-EasyGuide)\n",
    "- [Amortized guides](#Amortized-guides)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.contrib.easyguide import easy_guide\n",
    "from pyro.optim import Adam\n",
    "from torch.distributions import constraints\n",
    "\n",
    "pyro.enable_validation(True)\n",
    "smoke_test = ('CI' in os.environ)\n",
    "assert pyro.__version__.startswith('1.5.1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling time series data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consider a time-series model with a slowly varying continuous latent state (a Gaussian random walk whose step scale is the global `drift`) and Bernoulli observations with a logistic link function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(batch, subsample, full_size):\n",
    "    batch = list(batch)\n",
    "    num_time_steps = len(batch)\n",
    "    drift = pyro.sample(\"drift\", dist.LogNormal(-1, 0.5))\n",
    "    with pyro.plate(\"data\", full_size, subsample=subsample):\n",
    "        z = 0.\n",
    "        for t in range(num_time_steps):\n",
    "            z = pyro.sample(\"state_{}\".format(t),\n",
    "                            dist.Normal(z, drift))\n",
    "            batch[t] = pyro.sample(\"obs_{}\".format(t),\n",
    "                                   dist.Bernoulli(logits=z),\n",
    "                                   obs=batch[t])\n",
    "    return torch.stack(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate some data directly from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_size = 100\n",
    "num_time_steps = 7\n",
    "pyro.set_rng_seed(123456789)\n",
    "data = model([None] * num_time_steps, torch.arange(full_size), full_size)\n",
    "assert data.shape == (num_time_steps, full_size)"
   ]
  },
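  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the generative process that `model` defines can be restated in plain PyTorch. This is a minimal sketch rather than part of the tutorial's code: the function name `simulate` and its `seed` argument are hypothetical, and the sketch ignores Pyro's plates and subsampling.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def simulate(num_time_steps=7, full_size=100, seed=0):\n",
    "    # Hypothetical helper: a plain-PyTorch restatement of the model above.\n",
    "    torch.manual_seed(seed)\n",
    "    # drift ~ LogNormal(-1, 0.5): one positive step scale shared by all series.\n",
    "    drift = torch.distributions.LogNormal(-1.0, 0.5).sample()\n",
    "    z = torch.zeros(full_size)\n",
    "    states, obs = [], []\n",
    "    for t in range(num_time_steps):\n",
    "        # Gaussian random walk: z_t ~ Normal(z_{t-1}, drift).\n",
    "        z = z + drift * torch.randn(full_size)\n",
    "        states.append(z)\n",
    "        # Logistic link: P(obs_t = 1) = sigmoid(z_t).\n",
    "        obs.append(torch.bernoulli(torch.sigmoid(z)))\n",
    "    return torch.stack(states), torch.stack(obs)\n",
    "\n",
    "states, obs = simulate()\n",
    "print(states.shape, obs.shape)\n",
    "```\n",
    "\n",
    "Both returned tensors are laid out as `(num_time_steps, full_size)`, the same layout as `data` above."
   ]
  },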
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing a guide without EasyGuide\n",
    "\n",
    "Consider a possible guide for this model in which we point-estimate the `drift` parameter using a `Delta` distribution and model the local time series using shared uncertainty but local means, via a `LowRankMultivariateNormal` distribution. There is a single global sample site, `drift`, which we guide with a `param` and a `sample` statement. We then declare a global pair of uncertainty parameters, `cov_diag` and `cov_factor`. Next, inside the `data` plate, we declare a local `loc` parameter using `pyro.param(..., event_dim=...)` so that it is subsampled along with the plate, and we draw the joint states at an auxiliary sample site. Finally we unpack that auxiliary site into one `Delta` site per time step. This pattern of unpacking an auxiliary site into `Delta`s is quite common."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rank = 3\n",
    "\n",
    "def guide(batch, subsample, full_size):\n",
    "    num_time_steps, batch_size = batch.shape\n",
    "\n",
    "    # MAP estimate the drift.\n",
    "    drift_loc = pyro.param(\"drift_loc\", lambda: torch.tensor(0.1),\n",
    "                           constraint=constraints.positive)\n",
    "    pyro.sample(\"drift\", dist.Delta(drift_loc))\n",
    "\n",
    "    # Model local states using shared uncertainty + local mean.\n",
    "    cov_diag = pyro.param(\"state_cov_diag\",\n",
    "                          lambda: torch.full((num_time_steps,), 0.01),\n",
    "                          constraint=constraints.positive)\n",
    "    cov_factor = pyro.param(\"state_cov_factor\",\n",
    "                            lambda: torch.randn(num_time_steps, rank) * 0.01)\n",
    "    with pyro.plate(\"data\", full_size, subsample=subsample):\n",
    "        # Sample local mean.\n",
    "        loc = pyro.param(\"state_loc\",\n",
    "                         lambda: torch.full((full_size, num_time_steps), 0.5),\n",
    "                         event_dim=1)\n",
    "        states = pyro.sample(\"states\",\n",
    "                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),\n",
    "                             infer={\"is_auxiliary\": True})\n",
    "        # Unpack the joint states into one sample site per time step.\n",
    "        for t in range(num_time_steps):\n",
    "            pyro.sample(\"state_{}\".format(t), dist.Delta(states[:, t]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train using [SVI](http://docs.pyro.ai/en/stable/inference_algos.html#module-pyro.infer.svi) and [Trace_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.trace_elbo.Trace_ELBO), manually batching data into small minibatches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):\n",
    "    full_size = data.size(-1)\n",
    "    pyro.get_param_store().clear()\n",
    "    pyro.set_rng_seed(123456789)\n",
    "    svi = SVI(model, guide, Adam({\"lr\": 0.02}), Trace_ELBO())\n",
    "    for epoch in range(num_epochs):\n",
    "        pos = 0\n",
    "        losses = []\n",
    "        while pos < full_size:\n",
    "            subsample = torch.arange(pos, pos + batch_size)\n",
    "            batch = data[:, pos:pos + batch_size]\n",
    "            pos += batch_size\n",
    "            losses.append(svi.step(batch, subsample, full_size=full_size))\n",
    "        epoch_loss = sum(losses) / len(losses)\n",
    "        if epoch % 10 == 0:\n",
    "            print(\"epoch {} loss = {}\".format(epoch, epoch_loss / data.numel()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(guide)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using EasyGuide\n",
    "\n",
    "Now let's simplify using the `@easy_guide` decorator. Our modifications are:\n",
    "1. Decorate with `@easy_guide(model)` and add `self` as the first argument.\n",
    "2. Replace the `Delta` guide for drift with a simple `map_estimate()`.\n",
    "3. Select a `group` of model sites and read their concatenated `event_shape`.\n",
    "4. Replace the auxiliary site and `Delta` slices with a single `group.sample()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@easy_guide(model)\n",
    "def guide(self, batch, subsample, full_size):\n",
    "    # MAP estimate the drift.\n",
    "    self.map_estimate(\"drift\")\n",
    "\n",
    "    # Model local states using shared uncertainty + local mean.\n",
    "    group = self.group(match=\"state_[0-9]*\")  # Selects all local variables.\n",
    "    cov_diag = pyro.param(\"state_cov_diag\",\n",
    "                          lambda: torch.full(group.event_shape, 0.01),\n",
    "                          constraint=constraints.positive)\n",
    "    cov_factor = pyro.param(\"state_cov_factor\",\n",
    "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n",
    "    with self.plate(\"data\", full_size, subsample=subsample):\n",
    "        # Sample local mean.\n",
    "        loc = pyro.param(\"state_loc\",\n",
    "                         lambda: torch.full((full_size,) + group.event_shape, 0.5),\n",
    "                         event_dim=1)\n",
    "        # Automatically sample the joint latent, then unpack and replay model sites.\n",
    "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
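    "Note that we've used `group.event_shape` to determine the total flattened, concatenated shape of all matched sites in the group. Here the group contains one scalar site per time step, so `group.event_shape` has a single dimension of size `num_time_steps`."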
   ]
  },
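  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see which site names a pattern such as `state_[0-9]*` selects, it can help to test the regular expression on its own. The sketch below uses only Python's `re` module and a hypothetical list of site names. It assumes that a site is selected when `re.match` succeeds on its name, that is, the pattern is anchored at the start of the name but not at its end.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "pattern = 'state_[0-9]*'\n",
    "# Hypothetical site names mirroring the model above.\n",
    "site_names = ['drift'] + ['state_{}'.format(t) for t in range(7)]\n",
    "\n",
    "# Assumption: selection behaves like re.match (anchored at the start only).\n",
    "selected = [name for name in site_names if re.match(pattern, name)]\n",
    "print(selected)\n",
    "\n",
    "# The loose pattern also accepts names with no digits after the prefix.\n",
    "loose = bool(re.match(pattern, 'state_scale'))\n",
    "strict = bool(re.match('state_[0-9]+$', 'state_scale'))\n",
    "print(loose, strict)\n",
    "```\n",
    "\n",
    "Because `*` allows zero digits and the match is not anchored at the end, a stricter pattern such as `state_[0-9]+$` may be preferable when other latent site names share the `state_` prefix."
   ]
  },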
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(guide)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Amortized guides\n",
    "\n",
    "`EasyGuide` also makes it easy to write amortized guides, that is, guides where we learn a function that predicts latent variables from data rather than learning one parameter per datapoint. Let's modify the last guide to predict the latent `loc` as an affine function of the observed data, rather than memorizing each datapoint's latent variable. This amortized guide is more useful in practice because it can be applied to new data without learning additional per-datapoint parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@easy_guide(model)\n",
    "def guide(self, batch, subsample, full_size):\n",
    "    num_time_steps, batch_size = batch.shape\n",
    "    self.map_estimate(\"drift\")\n",
    "\n",
    "    group = self.group(match=\"state_[0-9]*\")\n",
    "    cov_diag = pyro.param(\"state_cov_diag\",\n",
    "                          lambda: torch.full(group.event_shape, 0.01),\n",
    "                          constraint=constraints.positive)\n",
    "    cov_factor = pyro.param(\"state_cov_factor\",\n",
    "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n",
    "\n",
    "    # Predict latent propensity as an affine function of observed data.\n",
    "    if not hasattr(self, \"nn\"):\n",
    "        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())\n",
    "        self.nn.weight.data.fill_(1.0 / num_time_steps)\n",
    "        self.nn.bias.data.fill_(-0.5)\n",
    "    pyro.module(\"state_nn\", self.nn)\n",
    "    with self.plate(\"data\", full_size, subsample=subsample):\n",
    "        loc = self.nn(batch.t())\n",
    "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))"
   ]
  },
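  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes involved in the amortized guide can be checked in plain PyTorch, without Pyro. The following is a minimal sketch under stated assumptions: the random `batch` is stand-in data, `net` is a hypothetical name for the linear layer, and `torch.distributions.LowRankMultivariateNormal` stands in for the `dist.LowRankMultivariateNormal` used in the guide.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "num_time_steps, batch_size, rank = 7, 20, 3\n",
    "# Stand-in minibatch of binary observations, laid out as (time, batch) like `data`.\n",
    "batch = torch.bernoulli(torch.full((num_time_steps, batch_size), 0.5))\n",
    "\n",
    "# Affine map from one observed series to its predicted latent mean.\n",
    "net = torch.nn.Linear(num_time_steps, num_time_steps)\n",
    "net.weight.data.fill_(1.0 / num_time_steps)\n",
    "net.bias.data.fill_(-0.5)\n",
    "\n",
    "# Transpose so that each row is one time series: (batch, time).\n",
    "loc = net(batch.t())\n",
    "cov_factor = torch.randn(num_time_steps, rank) * 0.01\n",
    "cov_diag = torch.full((num_time_steps,), 0.01)\n",
    "q = torch.distributions.LowRankMultivariateNormal(loc, cov_factor, cov_diag)\n",
    "print(q.batch_shape, q.event_shape)\n",
    "```\n",
    "\n",
    "With this initialization every entry of a row of `loc` starts at that series' fraction of ones minus 0.5, and the covariance parameters broadcast across the batch dimension."
   ]
  },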
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(guide)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
